-
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) #10980
base: main
Are you sure you want to change the base?
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
This pull request has merge conflicts that must be resolved before it can be |
7cd9a24
to
f506458
Compare
It seems like this PR is implementing ideas similar to those implemented in PR #9302 for the V0 engine. That PR created some issues that were addressed in PR #11898 and which may exist in the proposed V1 code. In particular, the proposed code currently does not properly handle the case when a seed value is provided for the parent request; the seed value is duplicated in child requests, leading to identical outputs in the child requests. The fix in #11898 was simply to move the copying of the Additionally, the proposed code for the V1 engine defines |
This pull request has merge conflicts that must be resolved before it can be |
3d5b962
to
bf3cfd0
Compare
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Thanks @afeldman-nm, appreciate the reply. One thing I'm wondering is whether matching n>1 throughput performance pre-0.6.4 is a priority for this change or follow-up work. It seems like fundamentally this new approach confers lower throughput. Relatedly, is there an architectural reason why this new approach is necessary in v1 or is it possible forking could be added later on if it makes a big difference? |
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Merged #13421 (LLMEngine support for parallel sampling) into this PR. |
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @afeldman-nm for the great work.
I am still reviewing the llm_engine.py
changes but wanted to post the comments I have so far. I feel like there may still be room to simplify/condense things a bit more, and possibly move more logic out of the main classes.
async for out in merge_async_iterators(*gens): | ||
yield out[1] # out[0] is index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
async for out in merge_async_iterators(*gens): | |
yield out[1] # out[0] is index | |
async for _, out in merge_async_iterators(*gens): | |
yield out |
@@ -241,6 +244,56 @@ async def generate( | |||
await self.abort(request_id) | |||
raise | |||
|
|||
async def _generate_parallel_sampling( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering whether it might be cleaner to also move this method into ParallelSamplingRequestManager, but it takes the base
generate` method as an additional arg.
) -> AsyncGenerator[RequestOutput, None]: | ||
"""Generate completions for parallel sampling requests.""" | ||
req_mgr = ParallelSamplingRequestManager(request_id, sampling_params) | ||
n = req_mgr.n |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nti: no need for this variable
# Aggregate generators for n child requests | ||
gens: List[AsyncGenerator[RequestOutput, None]] = [] | ||
for idx in range(n): | ||
c_sampling_params = req_mgr.get_child_sampling_params(idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about a single method that returns both the child sampling params and request id?
Child `sampling_params` instance. | ||
""" | ||
seed = self.sampling_params.seed | ||
if seed is None and self.cached_child_sampling_params: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
if seed is None and self.cached_child_sampling_params: | |
if self.cached_child_sampling_params: |
# Note: will be sorted by index later | ||
self.request_output.outputs.append(new_completion) | ||
|
||
def _get_parent_request_output(self) -> RequestOutput: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggested rename:
def _get_parent_request_output(self) -> RequestOutput: | |
def _get_final_request_output(self) -> RequestOutput: |
from vllm.sampling_params import RequestOutputKind, SamplingParams | ||
|
||
|
||
class ParallelSamplingRequestManager: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT about renaming this to
class ParallelSamplingRequestManager: | |
class ParallelSamplingRequest: |
def _num_parallel_sampling_requests(self) -> int: | ||
return len(self.parallel_parent_reqs) | ||
|
||
def _num_parallel_sampling_child_requests(self) -> int: | ||
return len(self.parallel_child_reqs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO these helper methods are unnecessary (just adds more LOC)
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else | ||
self._do_reset_parallel_sampling) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else | |
self._do_reset_parallel_sampling) | |
self._do_reset_parallel_sampling |= num_parallel_reqs > 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, looking again, this would be clearer:
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else | |
self._do_reset_parallel_sampling) | |
if self.parallel_parent_requests: | |
self._do_reset_parallel_sampling = True |
if num_parallel_reqs > 0 and len(request_outputs) > 0: | ||
# Process parallel sampling child request outputs | ||
return self._aggregate_parallel_sampling_outputs(request_outputs) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: redundant else
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else | ||
self._do_reset_parallel_sampling) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, looking again, this would be clearer:
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else | |
self._do_reset_parallel_sampling) | |
if self.parallel_parent_requests: | |
self._do_reset_parallel_sampling = True |
prompt_adapter_request=prompt_adapter_request, | ||
priority=priority) | ||
|
||
def _add_request_parallel_sampling( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar thought to the async case, perhaps this could be moved into the other class? (along with the state ... could have a subclass with that just for the sync case)
This PR adds support for parallel sampling to v1 AsyncLLM and LLMEngine.
Parallel sampling is implemented outside the engine. This does not impact the vLLM v0 parallel sampling implementation.
Design doc: https://docs.google.com/document/d/1_fvbHVCfexkPAj2Vx0Q0gNvE-b53WwtrLrD6NldgBFU/edit?usp=sharing (message me if you need access permissions)
A request with
n>1
will spawnn
requests withn=1
and aggregate their outputs in accordance with theoutput_kind
.If prefix caching is enabled, an initial warmup request withUpdate: The vLLM v1 engine can exploit APC when a prompt repeats within a batch, even if that prompt was not seen in a previous batch. Therefore, no warmup request is required.max_tokens=1
will be sent to the engine to fill the prefix cache.The abstractions are cleanest for async v1 engine because AsyncLLM presents a
generate()
method that handles adding requests and running the engine; parallel sampling is implemented by writing a wrapper which invokes this methodn
times against parallel sampling child requests.v1 LLMEngine presents
add_request()
andstep()
methods, so parallel sampling is implemented by writing anadd_request()
wrapper which branches parallel sampling requests inton
add_request()
calls for child requests, and then havingstep()
aggregate child request outputs into a parent request output.FIX #13419
FIX #13420